{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:46:35.421332500Z",
     "start_time": "2024-04-25T02:46:26.872001800Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch\n",
    "\n",
    "# Build the LoRA layer.\n",
    "class LoRALayer(nn.Module):\n",
    "    \"\"\"Low-rank adapter that computes ``alpha * (x @ A @ B)``.\n",
    "\n",
    "    Args:\n",
    "        in_dim: Size of the last dimension of the input.\n",
    "        out_dim: Size of the last dimension of the output.\n",
    "        rank: Inner dimension shared by ``A`` and ``B``; must be positive.\n",
    "        alpha: Scalar multiplier applied to the low-rank product.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``rank`` is not positive.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_dim, out_dim, rank,  alpha):\n",
    "        super().__init__()\n",
    "        if rank <= 0:\n",
    "            raise ValueError(f'rank must be a positive integer, got {rank}')\n",
    "        std_dev = rank ** -0.5\n",
    "        # Zero-mean Gaussian init for A: torch.randn, not torch.rand, which\n",
    "        # would make every entry of A non-negative.\n",
    "        self.A = nn.Parameter(torch.randn(in_dim, rank) * std_dev)\n",
    "        # B starts at zero, so the adapter output is zero right after construction.\n",
    "        self.B = nn.Parameter(torch.zeros(rank, out_dim))\n",
    "        self.alpha = alpha\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the scaled low-rank projection of ``x``.\"\"\"\n",
    "        return self.alpha * (x @ self.A @ self.B)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "class LinearWithLoRA(nn.Module):\n",
    "    \"\"\"Wrap a linear layer and add a LoRA branch to its output.\n",
    "\n",
    "    Args:\n",
    "        linear: Layer exposing ``in_features`` and ``out_features``.\n",
    "        rank: Rank forwarded to ``LoRALayer``.\n",
    "        alpha: Scaling forwarded to ``LoRALayer``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, linear, rank, alpha):\n",
    "        super().__init__()\n",
    "        self.linear = linear\n",
    "        # The adapter maps the same input/output sizes as the wrapped layer.\n",
    "        self.lora = LoRALayer(\n",
    "            linear.in_features, linear.out_features, rank, alpha\n",
    "        )\n",
    "\n",
    "    def forward(self,x):\n",
    "        \"\"\"Return the wrapped layer's output plus the LoRA branch output.\"\"\"\n",
    "        base_out = self.linear(x)\n",
    "        lora_out = self.lora(x)\n",
    "        return base_out + lora_out"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:46:35.443764100Z",
     "start_time": "2024-04-25T02:46:35.423327Z"
    }
   },
   "id": "5043ad65d0e71bce"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original output: tensor([[0.6639, 0.4487]], grad_fn=<AddmmBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Fix the RNG seed so the layer weights and the input are reproducible.\n",
    "torch.manual_seed(123)\n",
    "layer = nn.Linear(in_features=10, out_features=2)\n",
    "x = torch.randn(1, 10)\n",
    "\n",
    "baseline_out = layer(x)\n",
    "print(\"Original output:\", baseline_out)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:46:39.588691800Z",
     "start_time": "2024-04-25T02:46:38.957326500Z"
    }
   },
   "id": "ed917b2c22108634"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LoRA output: tensor([[0.6639, 0.4487]], grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# LoRALayer initialises B to zeros, so this is expected to match the original output.\n",
    "layer_lora_1 = LinearWithLoRA(layer, rank=2, alpha=4)\n",
    "lora_out = layer_lora_1(x)\n",
    "print(\"LoRA output:\", lora_out)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:46:41.038471900Z",
     "start_time": "2024-04-25T02:46:40.956656600Z"
    }
   },
   "id": "7d88cf835c6a2479"
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "class TestMLP(nn.Module):\n",
    "    \"\"\"Multilayer perceptron: three linear layers with ReLU in between.\n",
    "\n",
    "    Args:\n",
    "        num_features: Input size of the first linear layer.\n",
    "        num_hidden1: Output size of the first linear layer.\n",
    "        num_hidden2: Output size of the second linear layer.\n",
    "        num_class: Output size of the last linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_hidden1, num_hidden2, num_class):\n",
    "        super().__init__()\n",
    "        stack = [\n",
    "            nn.Linear(num_features, num_hidden1),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_hidden1, num_hidden2),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(num_hidden2, num_class),\n",
    "        ]\n",
    "        self.layers = nn.Sequential(*stack)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run ``x`` through the sequential stack.\"\"\"\n",
    "        return self.layers(x)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:46:41.953286400Z",
     "start_time": "2024-04-25T02:46:41.937329600Z"
    }
   },
   "id": "32464af84070e2ec"
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TestMLP(\n",
      "  (layers): Sequential(\n",
      "    (0): Linear(in_features=64, out_features=32, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=32, out_features=64, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=64, out_features=3, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Layer sizes for the demo network.\n",
    "num_features, num_hidden1 = 64, 32\n",
    "num_hidden2, num_class = 64, 3\n",
    "\n",
    "model = TestMLP(num_features, num_hidden1, num_hidden2, num_class)\n",
    "print(model)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T03:00:15.820683900Z",
     "start_time": "2024-04-25T03:00:15.802611Z"
    }
   },
   "id": "2713442117d0ea52"
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name: layers\n",
      "module Sequential(\n",
      "  (0): Linear(in_features=64, out_features=32, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=32, out_features=64, bias=True)\n",
      "  (3): ReLU()\n",
      "  (4): Linear(in_features=64, out_features=3, bias=True)\n",
      ")\n",
      "name: 0\n",
      "module Linear(in_features=64, out_features=32, bias=True)\n",
      "name: 1\n",
      "module ReLU()\n",
      "name: 2\n",
      "module Linear(in_features=32, out_features=64, bias=True)\n",
      "name: 3\n",
      "module ReLU()\n",
      "name: 4\n",
      "module Linear(in_features=64, out_features=3, bias=True)\n"
     ]
    }
   ],
   "source": [
    "def convert_layers(model, rank=4, alpha=8, verbose=True):\n",
    "    \"\"\"Recursively replace every ``nn.Linear`` inside ``model`` with ``LinearWithLoRA``.\n",
    "\n",
    "    The replacement happens in place on ``model``; nothing is returned.\n",
    "    Children that are already ``LinearWithLoRA`` are left untouched, so calling\n",
    "    the function again does not stack a second adapter on the same layer.\n",
    "\n",
    "    Args:\n",
    "        model: Module whose children are scanned.\n",
    "        rank: LoRA rank passed to each wrapper (default 4).\n",
    "        alpha: LoRA scaling passed to each wrapper (default 8).\n",
    "        verbose: When True, print the name and repr of each visited child.\n",
    "    \"\"\"\n",
    "    # Snapshot the children so replacing entries cannot interfere with iteration.\n",
    "    for name, module in list(model.named_children()):\n",
    "        if verbose:\n",
    "            print('name:',name)\n",
    "            print('module', module)\n",
    "        if isinstance(module, LinearWithLoRA):\n",
    "            # Already wrapped: recursing would wrap its inner ``linear`` again.\n",
    "            continue\n",
    "        if isinstance(module, nn.Linear):\n",
    "            setattr(model, name, LinearWithLoRA(module, rank=rank, alpha=alpha))\n",
    "        else:\n",
    "            convert_layers(module, rank=rank, alpha=alpha, verbose=verbose)\n",
    "convert_layers(model)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T03:00:17.116895100Z",
     "start_time": "2024-04-25T03:00:17.102931Z"
    }
   },
   "id": "7e986ff890ab6fd1"
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TestMLP(\n",
      "  (layers): Sequential(\n",
      "    (0): LinearWithLoRA(\n",
      "      (linear): Linear(in_features=64, out_features=32, bias=True)\n",
      "      (lora): LoRALayer()\n",
      "    )\n",
      "    (1): ReLU()\n",
      "    (2): LinearWithLoRA(\n",
      "      (linear): Linear(in_features=32, out_features=64, bias=True)\n",
      "      (lora): LoRALayer()\n",
      "    )\n",
      "    (3): ReLU()\n",
      "    (4): LinearWithLoRA(\n",
      "      (linear): Linear(in_features=64, out_features=3, bias=True)\n",
      "      (lora): LoRALayer()\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Show the module tree after the Linear layers were wrapped.\n",
    "print(str(model))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:50:32.515819Z",
     "start_time": "2024-04-25T02:50:32.493877300Z"
    }
   },
   "id": "5d4507ca913c0a3"
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layers.0.linear.weight:True\n",
      "layers.0.linear.bias:True\n",
      "layers.0.lora.A:True\n",
      "layers.0.lora.B:True\n",
      "layers.2.linear.weight:True\n",
      "layers.2.linear.bias:True\n",
      "layers.2.lora.A:True\n",
      "layers.2.lora.B:True\n",
      "layers.4.linear.weight:True\n",
      "layers.4.linear.bias:True\n",
      "layers.4.lora.A:True\n",
      "layers.4.lora.B:True\n"
     ]
    }
   ],
   "source": [
    "# List every parameter together with its requires_grad flag.\n",
    "for param_name, tensor in model.named_parameters():\n",
    "    print(f'{param_name}:{tensor.requires_grad}')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:51:01.422033100Z",
     "start_time": "2024-04-25T02:51:01.401092700Z"
    }
   },
   "id": "5035ea98fd8fd2ad"
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "# Wrap each Linear layer of the Sequential (indices 0, 2, 4) with LoRA.\n",
    "# The isinstance guard keeps this cell safe when the layers were already wrapped\n",
    "# (e.g. by convert_layers above): LinearWithLoRA has no in_features attribute,\n",
    "# so wrapping it a second time would raise AttributeError.\n",
    "for layer_idx in (0, 2, 4):\n",
    "    if isinstance(model.layers[layer_idx], nn.Linear):\n",
    "        model.layers[layer_idx] = LinearWithLoRA(model.layers[layer_idx], rank=4, alpha=8)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-22T02:34:13.357760700Z",
     "start_time": "2024-04-22T02:34:13.341757600Z"
    }
   },
   "id": "f2112128ba1b8f56"
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TestMLP(\n",
      "  (layers): Sequential(\n",
      "    (0): LinearWithLoRA(\n",
      "      (linear): Linear(in_features=64, out_features=32, bias=True)\n",
      "      (lora): LoRALayer()\n",
      "    )\n",
      "    (1): ReLU()\n",
      "    (2): LinearWithLoRA(\n",
      "      (linear): Linear(in_features=32, out_features=64, bias=True)\n",
      "      (lora): LoRALayer()\n",
      "    )\n",
      "    (3): ReLU()\n",
      "    (4): LinearWithLoRA(\n",
      "      (linear): Linear(in_features=64, out_features=3, bias=True)\n",
      "      (lora): LoRALayer()\n",
      "    )\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Show the module tree again to confirm the LoRA wrappers are in place.\n",
    "print(str(model))"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-22T02:43:21.081936200Z",
     "start_time": "2024-04-22T02:43:21.051865600Z"
    }
   },
   "id": "13fd3d2e85230f4c"
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([[0.4290, 0.4223, 0.4418, 0.4131, 0.3538, 0.4257, 0.3908, 0.3639, 0.4376,\n         0.4444, 0.4105, 0.4212, 0.4105, 0.4042, 0.3824, 0.3944, 0.3684, 0.4318,\n         0.4080, 0.3716, 0.3509, 0.4130, 0.4303, 0.4321, 0.4073, 0.4147, 0.4601,\n         0.4353, 0.4566, 0.3958, 0.4717, 0.4171, 0.3877, 0.4053, 0.4421, 0.4552,\n         0.3942, 0.4467, 0.4286, 0.4359, 0.3872, 0.3767, 0.4176, 0.3911, 0.3775,\n         0.5211, 0.4224, 0.3992, 0.4345, 0.3712, 0.3642, 0.3961, 0.4215, 0.3368,\n         0.4231, 0.3695, 0.4020, 0.4400, 0.3733, 0.4158, 0.4570, 0.3898, 0.4052,\n         0.4443]])"
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# L2 norm of each weight column (reduced over dim 0) of the first wrapped layer.\n",
    "first_linear = model.layers[0].linear\n",
    "first_linear.weight.norm(p=2, dim=0, keepdim=True)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-22T02:48:44.342432400Z",
     "start_time": "2024-04-22T02:48:44.261956600Z"
    }
   },
   "id": "8686a69aeb370a5f"
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LinearWithLoRA(\n",
      "  (linear): Linear(in_features=64, out_features=32, bias=True)\n",
      "  (lora): LoRALayer()\n",
      ")\n",
      "Parameter containing:\n",
      "tensor([[-0.1181, -0.0842, -0.0498,  ..., -0.0536,  0.1188, -0.0621],\n",
      "        [ 0.0551,  0.0490,  0.0349,  ...,  0.0645,  0.0883, -0.1213],\n",
      "        [-0.1061, -0.1217,  0.0472,  ..., -0.0622,  0.1077, -0.1028],\n",
      "        ...,\n",
      "        [-0.0842,  0.0930, -0.0598,  ..., -0.0115,  0.0299, -0.0728],\n",
      "        [ 0.1118,  0.0820,  0.0991,  ...,  0.0118, -0.0881,  0.0676],\n",
      "        [ 0.0783, -0.0778,  0.0060,  ...,  0.0182,  0.0307, -0.0319]],\n",
      "       requires_grad=True)\n",
      "Parameter containing:\n",
      "tensor([-0.0438,  0.0379,  0.0179,  0.0866, -0.0935,  0.1226, -0.1163, -0.0920,\n",
      "        -0.0220, -0.1086,  0.0843, -0.0219,  0.0048,  0.0405,  0.0141,  0.1067,\n",
      "         0.0897,  0.0425, -0.1199,  0.0470, -0.0206,  0.0212,  0.0958, -0.0589,\n",
      "        -0.0080, -0.1033,  0.0453, -0.0274, -0.0461,  0.0069,  0.0059,  0.0228],\n",
      "       requires_grad=True)\n",
      "Parameter containing:\n",
      "tensor([[3.4769e-01, 1.4958e-01, 4.6664e-01, 1.3308e-01],\n",
      "        [2.5660e-01, 5.3322e-02, 1.1694e-01, 3.3659e-01],\n",
      "        [2.2197e-01, 4.0808e-01, 2.7168e-01, 3.7301e-02],\n",
      "        [3.2966e-01, 4.3251e-01, 1.8411e-01, 1.6940e-01],\n",
      "        [4.1231e-02, 4.4806e-01, 8.1315e-02, 2.8213e-01],\n",
      "        [4.5784e-01, 2.8490e-01, 3.0800e-02, 2.3260e-01],\n",
      "        [2.3099e-02, 2.3788e-01, 4.7768e-01, 2.6270e-01],\n",
      "        [3.8744e-01, 1.9199e-02, 3.8960e-01, 2.9224e-01],\n",
      "        [3.4128e-01, 1.5870e-01, 4.4300e-02, 1.6574e-01],\n",
      "        [2.4914e-01, 1.5658e-02, 1.0003e-01, 3.5678e-02],\n",
      "        [4.8575e-01, 3.7997e-01, 1.7147e-01, 2.5824e-01],\n",
      "        [2.5086e-03, 2.1639e-01, 2.8741e-01, 8.2283e-02],\n",
      "        [1.9084e-01, 2.8903e-01, 4.0755e-01, 2.8055e-01],\n",
      "        [3.9756e-01, 4.0692e-01, 2.5932e-01, 2.3080e-01],\n",
      "        [4.2067e-01, 2.7529e-01, 3.3700e-02, 4.2785e-01],\n",
      "        [1.4143e-01, 2.5815e-01, 1.8906e-01, 1.7065e-02],\n",
      "        [3.6510e-01, 4.0044e-01, 2.0031e-01, 2.6841e-02],\n",
      "        [8.4753e-02, 2.2603e-01, 4.9669e-01, 4.8972e-02],\n",
      "        [3.0481e-01, 2.4413e-01, 2.4505e-01, 2.5033e-01],\n",
      "        [3.4420e-01, 4.3322e-01, 4.4987e-01, 1.9115e-02],\n",
      "        [3.5904e-01, 4.4744e-01, 4.6405e-02, 5.9957e-02],\n",
      "        [4.5081e-02, 4.6698e-01, 8.2647e-02, 3.1697e-01],\n",
      "        [1.3002e-03, 2.7410e-01, 2.0225e-01, 1.4322e-01],\n",
      "        [2.4356e-01, 2.9629e-01, 1.5537e-01, 1.9653e-02],\n",
      "        [4.2938e-01, 2.4280e-01, 2.3219e-01, 2.7249e-01],\n",
      "        [1.5461e-01, 2.3969e-01, 1.2364e-02, 2.7849e-02],\n",
      "        [3.1201e-01, 3.5877e-01, 2.0191e-01, 2.6144e-01],\n",
      "        [3.6163e-01, 5.5893e-02, 4.8385e-01, 4.5235e-02],\n",
      "        [2.3547e-01, 2.1719e-01, 8.5094e-02, 2.3156e-01],\n",
      "        [1.8574e-01, 6.2233e-03, 3.8006e-01, 9.5574e-02],\n",
      "        [4.1799e-01, 1.9825e-01, 2.2901e-01, 1.1899e-01],\n",
      "        [2.9523e-01, 1.9308e-01, 1.1139e-01, 2.2789e-01],\n",
      "        [1.2930e-01, 2.4445e-01, 3.0867e-01, 4.5769e-01],\n",
      "        [4.8291e-01, 3.8409e-01, 5.4396e-02, 7.0665e-02],\n",
      "        [3.1351e-01, 1.6483e-01, 3.5857e-01, 1.5833e-01],\n",
      "        [4.9250e-01, 4.7023e-02, 4.6383e-01, 2.6428e-01],\n",
      "        [8.4805e-02, 2.8030e-02, 3.1227e-01, 1.5006e-01],\n",
      "        [3.0071e-01, 2.1502e-01, 1.5701e-01, 1.1489e-01],\n",
      "        [2.6912e-01, 1.1310e-01, 4.2946e-01, 3.1327e-01],\n",
      "        [1.9738e-01, 2.2724e-01, 2.2680e-01, 3.3938e-01],\n",
      "        [8.7045e-02, 1.0418e-01, 2.0969e-01, 4.8096e-01],\n",
      "        [5.2581e-02, 3.6308e-01, 3.1698e-01, 2.7035e-01],\n",
      "        [3.1743e-01, 3.0595e-01, 4.0074e-01, 2.4645e-02],\n",
      "        [4.9827e-01, 1.9090e-01, 2.3085e-01, 1.9934e-01],\n",
      "        [1.3454e-01, 4.2285e-01, 1.4789e-01, 2.1503e-01],\n",
      "        [4.2019e-01, 3.4932e-01, 3.4867e-01, 9.2309e-02],\n",
      "        [3.1572e-01, 3.1788e-01, 3.5414e-01, 7.5776e-03],\n",
      "        [3.5989e-01, 3.0097e-01, 4.2212e-01, 4.6567e-01],\n",
      "        [4.0863e-01, 4.8323e-01, 2.5105e-04, 1.9984e-01],\n",
      "        [1.2298e-02, 4.4716e-01, 8.2384e-02, 4.8091e-01],\n",
      "        [7.3011e-02, 2.0120e-01, 2.8994e-01, 2.5072e-01],\n",
      "        [2.4913e-01, 4.6542e-01, 1.6503e-01, 4.0151e-01],\n",
      "        [3.4772e-01, 3.1041e-01, 1.2906e-01, 4.3781e-01],\n",
      "        [1.3520e-01, 4.6079e-01, 2.5575e-01, 4.2394e-02],\n",
      "        [4.8827e-02, 9.3357e-02, 4.5437e-01, 3.5113e-01],\n",
      "        [2.9004e-01, 2.1169e-01, 3.2923e-01, 2.5868e-01],\n",
      "        [9.5278e-02, 2.2772e-01, 7.5298e-02, 2.4745e-01],\n",
      "        [8.3639e-02, 4.3569e-01, 4.2180e-01, 2.7622e-01],\n",
      "        [4.5535e-01, 1.8373e-01, 4.9494e-01, 7.2837e-03],\n",
      "        [3.7708e-01, 2.2672e-01, 4.1715e-02, 3.1138e-01],\n",
      "        [4.0578e-01, 2.2975e-01, 4.9528e-01, 4.2176e-01],\n",
      "        [2.3888e-01, 1.6836e-01, 4.1486e-01, 2.4548e-02],\n",
      "        [4.0845e-01, 2.5256e-01, 2.7059e-02, 2.7183e-01],\n",
      "        [4.8519e-01, 4.5048e-01, 4.7933e-01, 2.7883e-01]], requires_grad=True)\n",
      "Parameter containing:\n",
      "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
      "         0., 0., 0., 0., 0., 0., 0., 0.]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "# Print the first sub-module of each top-level child together with its parameters.\n",
    "for child in model.children():\n",
    "    for sub_module in child.children():\n",
    "        print(sub_module)\n",
    "        for param in sub_module.parameters():\n",
    "            print(param)\n",
    "        # Only the first sub-module is inspected.\n",
    "        break"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-22T02:34:25.585500600Z",
     "start_time": "2024-04-22T02:34:25.494638700Z"
    }
   },
   "id": "8e4ed896ccabf110"
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "layers.0.linear.weight:False\n",
      "layers.0.linear.bias:False\n",
      "layers.0.lora.A:True\n",
      "layers.0.lora.B:True\n",
      "layers.2.linear.weight:False\n",
      "layers.2.linear.bias:False\n",
      "layers.2.lora.A:True\n",
      "layers.2.lora.B:True\n",
      "layers.4.linear.weight:False\n",
      "layers.4.linear.bias:False\n",
      "layers.4.lora.A:True\n",
      "layers.4.lora.B:True\n"
     ]
    }
   ],
   "source": [
    "def freeze_linear_layers(model):\n",
    "    \"\"\"Disable ``requires_grad`` on the parameters of every ``nn.Linear`` below ``model``.\n",
    "\n",
    "    Only descendants are considered; ``model`` itself is never frozen.\n",
    "    The change is made in place and nothing is returned.\n",
    "    \"\"\"\n",
    "    for sub_module in model.modules():\n",
    "        # ``modules()`` also yields ``model`` itself, which must be skipped.\n",
    "        if sub_module is not model and isinstance(sub_module, nn.Linear):\n",
    "            sub_module.requires_grad_(False)\n",
    "\n",
    "freeze_linear_layers(model)\n",
    "# After freezing, report which parameters still have requires_grad enabled.\n",
    "for param_name, tensor in model.named_parameters():\n",
    "    print(f'{param_name}:{tensor.requires_grad}')"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-04-25T02:51:07.637072100Z",
     "start_time": "2024-04-25T02:51:07.617565800Z"
    }
   },
   "id": "314d254f8fe625bd"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class LinearWithDoRA(nn.Module):\n",
    "    \"\"\"Linear layer adapted with DoRA (weight-decomposed low-rank adaptation).\n",
    "\n",
    "    The effective weight is ``m * V / ||V||_c`` with ``V = W0 + alpha * (A @ B).T``,\n",
    "    where ``||.||_c`` is the L2 norm of each column of the\n",
    "    ``(out_features, in_features)`` weight. ``m`` is initialised to the column\n",
    "    norms of ``W0``, so while ``B`` is zero the layer matches ``linear`` up to\n",
    "    floating-point rounding.\n",
    "\n",
    "    Args:\n",
    "        linear: The ``nn.Linear`` to adapt.\n",
    "        rank: Rank of the LoRA update.\n",
    "        alpha: Scaling applied to the LoRA update.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, linear, rank, alpha):\n",
    "        super().__init__()\n",
    "        self.linear = linear\n",
    "        self.lora = LoRALayer(\n",
    "            linear.in_features, linear.out_features, rank, alpha\n",
    "        )\n",
    "        # Learnable magnitude vector: one entry per weight column, shape (1, in_features).\n",
    "        self.m = nn.Parameter(\n",
    "            linear.weight.norm(p=2, dim=0, keepdim=True).detach().clone()\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the magnitude/direction-decomposed weight to ``x``.\"\"\"\n",
    "        # (in, rank) @ (rank, out), transposed to match linear.weight's (out, in).\n",
    "        lora_delta = self.lora.alpha * (self.lora.A @ self.lora.B).T\n",
    "        combined = self.linear.weight + lora_delta\n",
    "        # The epsilon avoids division by zero for all-zero columns.\n",
    "        column_norm = combined.norm(p=2, dim=0, keepdim=True) + 1e-9\n",
    "        dora_weight = self.m * (combined / column_norm)\n",
    "        return nn.functional.linear(x, dora_weight, self.linear.bias)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "cfb491bf03210d26"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
